{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import classification_report\n",
    "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "种类： {'virginica', 'setosa', 'versicolor'}\n",
      "行列： (150, 5)\n",
      "   sepal_length  sepal_width  petal_length  petal_width species\n",
      "0           5.1          3.5           1.4          0.2  setosa\n",
      "1           4.9          3.0           1.4          0.2  setosa\n",
      "2           4.7          3.2           1.3          0.2  setosa\n",
      "3           4.6          3.1           1.5          0.2  setosa\n",
      "4           5.0          3.6           1.4          0.2  setosa\n",
      "数据分布：\n",
      " versicolor    50\n",
      "virginica     50\n",
      "setosa        50\n",
      "Name: species, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "# Load the iris example dataset through seaborn.\n",
    "data = sns.load_dataset('iris')\n",
    "\n",
    "# Quick overview: distinct classes, frame shape, first rows, rows per class.\n",
    "print('种类：', set(data['species'].unique()))\n",
    "print('行列：', data.shape)\n",
    "print(data.head())\n",
    "print('数据分布：\\n', data['species'].value_counts())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# LDA: LinearDiscriminantAnalysis (linear discriminant analysis)\n",
    "def iris_LDA(train_x, train_y, test_x, test_y):\n",
    "    \"\"\"Train an LDA classifier and report how it performs on held-out data.\n",
    "\n",
    "    Args:\n",
    "        train_x: Feature matrix used to fit the classifier.\n",
    "        train_y: Class labels aligned with ``train_x``.\n",
    "        test_x: Feature matrix to predict on.\n",
    "        test_y: True class labels aligned with ``test_x``.\n",
    "\n",
    "    Returns:\n",
    "        The report produced by ``classification_report`` for the\n",
    "        predictions made on ``test_x``.\n",
    "    \"\"\"\n",
    "    # n_components only affects LDA.transform(), which is never called here,\n",
    "    # so it is left at its default. A hard-coded 2 would exceed the allowed\n",
    "    # maximum min(n_classes - 1, n_features) for two-class data.\n",
    "    clf = LDA()\n",
    "    clf.fit(train_x, train_y)\n",
    "    pre_y = clf.predict(test_x)\n",
    "\n",
    "    return classification_report(test_y, pre_y)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练(40%*150=60):测试(60%*150=90)=2:3\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "      setosa       1.00      1.00      1.00        30\n",
      "  versicolor       0.97      0.93      0.95        30\n",
      "   virginica       0.94      0.97      0.95        30\n",
      "\n",
      "    accuracy                           0.97        90\n",
      "   macro avg       0.97      0.97      0.97        90\n",
      "weighted avg       0.97      0.97      0.97        90\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print('训练(40%*150=60):测试(60%*150=90)=2:3')\n",
    "# Per-species split: the first 20 rows of every species are used for training,\n",
    "# all remaining rows for testing. Selecting by group and by column name avoids\n",
    "# depending on hard-coded row offsets and column positions.\n",
    "data_train = data.groupby('species').head(20)\n",
    "data_test = data.drop(index=data_train.index)\n",
    "# Guard the sizes announced in the header above.\n",
    "assert (len(data_train), len(data_test)) == (60, 90)\n",
    "\n",
    "train_x = data_train.drop(columns='species')  # feature columns used for training\n",
    "train_y = data_train['species']  # class labels used for training\n",
    "\n",
    "test_x = data_test.drop(columns='species')\n",
    "test_y = data_test['species']\n",
    "\n",
    "result = iris_LDA(train_x, train_y, test_x, test_y)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练(60%*150=90):测试(40%*150=60)=3:2\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "      setosa       1.00      1.00      1.00        20\n",
      "  versicolor       0.95      0.95      0.95        20\n",
      "   virginica       0.95      0.95      0.95        20\n",
      "\n",
      "    accuracy                           0.97        60\n",
      "   macro avg       0.97      0.97      0.97        60\n",
      "weighted avg       0.97      0.97      0.97        60\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print('训练(60%*150=90):测试(40%*150=60)=3:2')\n",
    "# Per-species split: the first 30 rows of every species are used for training,\n",
    "# all remaining rows for testing. Selecting by group and by column name avoids\n",
    "# depending on hard-coded row offsets and column positions.\n",
    "data_train = data.groupby('species').head(30)\n",
    "data_test = data.drop(index=data_train.index)\n",
    "# Guard the sizes announced in the header above.\n",
    "assert (len(data_train), len(data_test)) == (90, 60)\n",
    "\n",
    "train_x = data_train.drop(columns='species')  # feature columns used for training\n",
    "train_y = data_train['species']  # class labels used for training\n",
    "\n",
    "test_x = data_test.drop(columns='species')\n",
    "test_y = data_test['species']\n",
    "\n",
    "result2 = iris_LDA(train_x, train_y, test_x, test_y)\n",
    "print(result2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "训练(80%*150=120):测试(20%*150=30)=4:1\n",
      "              precision    recall  f1-score   support\n",
      "\n",
      "      setosa       1.00      1.00      1.00        10\n",
      "  versicolor       1.00      1.00      1.00        10\n",
      "   virginica       1.00      1.00      1.00        10\n",
      "\n",
      "    accuracy                           1.00        30\n",
      "   macro avg       1.00      1.00      1.00        30\n",
      "weighted avg       1.00      1.00      1.00        30\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print('训练(80%*150=120):测试(20%*150=30)=4:1')\n",
    "# Per-species split: the first 40 rows of every species are used for training,\n",
    "# all remaining rows for testing. Selecting by group and by column name avoids\n",
    "# depending on hard-coded row offsets and column positions.\n",
    "data_train = data.groupby('species').head(40)\n",
    "data_test = data.drop(index=data_train.index)\n",
    "# Guard the sizes announced in the header above.\n",
    "assert (len(data_train), len(data_test)) == (120, 30)\n",
    "\n",
    "train_x = data_train.drop(columns='species')  # feature columns used for training\n",
    "train_y = data_train['species']  # class labels used for training\n",
    "\n",
    "test_x = data_test.drop(columns='species')\n",
    "test_y = data_test['species']\n",
    "\n",
    "result3 = iris_LDA(train_x, train_y, test_x, test_y)\n",
    "print(result3)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
